import pickle
import numpy as np 
import random

import time

from sklearn import tree
from sklearn.calibration import CalibratedClassifierCV
from sklearn.ensemble import BaggingClassifier
from sklearn.utils import resample
from Utility import *
from GPcls import *
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.svm import SVC, SVR
from sklearn.linear_model import LogisticRegression, LinearRegression
import numpy as np 
import math
from scipy.stats import norm
from scipy.optimize import minimize_scalar


def SequentialPRT2(args, StatsPath, QueryPath, Data, ResonseRegionTestData):
	"""
	Prepare the warm-up data and run one sequential match-pair sign test.

	``args.SeqStartPoint`` is first increased in place while the last column of
	``Data[:args.SeqStartPoint]`` holds a single distinct value and the start
	point does not exceed ``args.Budget``. Nothing else happens unless the
	resulting start point is below ``args.Budget`` and ``args.TestType`` is
	'MatchPair_SignTest' or 'RegressionType_MatchPair_SignTest'.

	Column layout read from ``Data`` (with ``Cov_Feat = (n_columns - 5) / 2``):
	iid covariates, paired covariates, iid response, paired response,
	iid assignment, paired assignment, enrichment label.

	args: run configuration; this function reads SeqStartPoint, Budget, TestType,
		Bettingcls, Enrichmentcls, ScalingStrategy, SaveEnrichmentProb and Trial
	StatsPath: string prefix for the debug files written when args.SaveEnrichmentProb == 1
	QueryPath: not used by this function
	Data: 2-D array laid out as described above
	ResonseRegionTestData: forwarded unchanged to the selected test routine

	Returns None; the statistics returned by the test routine are not propagated.
	"""
	# Extend the warm-up prefix until it contains two different labels
	# (or until it has grown past the budget).
	while (len(np.unique(Data[:args.SeqStartPoint, -1])) == 1 and args.SeqStartPoint <= args.Budget):
		args.SeqStartPoint += 1
	TrSize = args.SeqStartPoint
	if not (args.SeqStartPoint < args.Budget):
		return None

	use_classifier = args.TestType == 'MatchPair_SignTest'
	use_regressors = args.TestType == 'RegressionType_MatchPair_SignTest'
	if not (use_classifier or use_regressors):
		return None

	t_begin = time.time()  # wall-clock stamp; not used further in this function
	# The five non-covariate columns are the two responses, the two
	# treatment/control assignments and the enrichment label.
	Cov_Feat = int((Data.shape[1] - 5) / 2)
	InitDataforBetting = _stack_initial_betting_rows(Data, TrSize, Cov_Feat)

	if use_classifier:
		# Enrichment classifier is trained on the iid covariates only.
		InitDataforEnrichment = np.zeros((TrSize, Cov_Feat + 1))
		InitDataforEnrichment[:, :Cov_Feat] = Data[:TrSize, :Cov_Feat]
		InitDataforEnrichment[:, -1] = Data[:TrSize, -1]
	else:
		# One regression set per arm (assignment 0 -> control, 1 -> treatment).
		InitData = Data[:TrSize]
		InitDataforControlRegression = _arm_regression_rows(InitData, TrSize, Cov_Feat, 0)
		InitDataforTreatementRegression = _arm_regression_rows(InitData, TrSize, Cov_Feat, 1)

	# Covariates + responses of every pair, and the three label columns.
	CovResponse = Data[:, :Cov_Feat * 2 + 2]
	Labels = Data[:, -3:]

	BettingCls = GetClassifier(args, args.Bettingcls, InitDataforBetting)

	if use_classifier:
		EnrichmentCls = GetClassifier(args, args.Enrichmentcls, InitDataforEnrichment, ImbalanceProcess=False)
		(Stats, QueryIndex, Enrichment_prob_All, Val_Selected_Points,
			EffectSize_along_labelspent, TPR_along_labelspent,
			Precision_along_labelspent, acc_along_labelspent) = MatchPairSignTest(
			args, BettingCls, EnrichmentCls, InitDataforBetting, InitDataforEnrichment,
			CovResponse, Labels, ResonseRegionTestData, ScalingStrategy=args.ScalingStrategy)
		t_finish = time.time()  # not used further in this function
		# debug dumps
		if args.SaveEnrichmentProb == 1:
			with open(StatsPath + 'Enrichment_prob_All%d.pkl'%args.Trial, 'wb') as f:
				pickle.dump(Enrichment_prob_All, f)
			dumps = (
				('Val_Selected_Points%d.npy', Val_Selected_Points),
				('EffectSize_along_labelspent%d.npy', EffectSize_along_labelspent),
				('TPR_along_labelspent%d.npy', TPR_along_labelspent),
				('Precision_along_labelspent%d.npy', Precision_along_labelspent),
				('acc_along_labelspent%d.npy', acc_along_labelspent),
			)
			for template, payload in dumps:
				np.save(StatsPath + template % args.Trial, payload)
	else:
		ControlRegressor, TreatmentRegressor = GetRegressor(args, InitDataforControlRegression, InitDataforTreatementRegression)
		(Stats, QueryIndex, Val_Selected_Points, EffectSize_along_labelspent,
			TPR_along_labelspent, Precision_along_labelspent,
			acc_along_labelspent) = RegressionType_MatchPair_SignTest(
			args, BettingCls, ControlRegressor, TreatmentRegressor, InitDataforBetting,
			InitDataforControlRegression, InitDataforTreatementRegression,
			CovResponse, Labels, ResonseRegionTestData, ScalingStrategy=args.ScalingStrategy)
		t_finish = time.time()  # not used further in this function
		# debug dumps
		if args.SaveEnrichmentProb == 1:
			dumps = (
				('Val_Selected_Points%d.npy', Val_Selected_Points),
				('EffectSize_Along_LabelSpent%d.npy', EffectSize_along_labelspent),
				('TPR_along_labelspent%d.npy', TPR_along_labelspent),
				('Precision_along_labelspent%d.npy', Precision_along_labelspent),
				('acc_along_labelspent%d.npy', acc_along_labelspent),
			)
			for template, payload in dumps:
				np.save(StatsPath + template % args.Trial, payload)


def _stack_initial_betting_rows(Data, TrSize, Cov_Feat):
	"""Lay the first TrSize pairs out as 2*TrSize rows of [covariates, response, assignment]."""
	stacked = np.zeros((TrSize * 2, Cov_Feat + 2))
	warmup = Data[:TrSize]
	iid_half = stacked[:TrSize]  # view: rows for the iid member of each pair
	paired_half = stacked[TrSize:TrSize * 2]  # view: rows for the paired member
	iid_half[:, :-2] = warmup[:, :Cov_Feat]
	paired_half[:, :-2] = warmup[:, Cov_Feat:2 * Cov_Feat]
	iid_half[:, Cov_Feat] = warmup[:, 2 * Cov_Feat]
	paired_half[:, Cov_Feat] = warmup[:, 2 * Cov_Feat + 1]
	iid_half[:, -1] = warmup[:, 2 * Cov_Feat + 2]
	paired_half[:, -1] = warmup[:, 2 * Cov_Feat + 3]
	return stacked


def _arm_regression_rows(InitData, TrSize, Cov_Feat, arm):
	"""Collect, per pair, the covariates and response of the member whose assignment equals ``arm``."""
	iid_hit = InitData[:, 2 * Cov_Feat + 2] == arm
	paired_hit = InitData[:, 2 * Cov_Feat + 3] == arm
	rows = np.zeros((TrSize, Cov_Feat + 1))
	# iid member is written first; the paired member overrides it if both match.
	rows[iid_hit, :Cov_Feat] = InitData[iid_hit, :Cov_Feat]
	rows[paired_hit, :Cov_Feat] = InitData[paired_hit, Cov_Feat:2 * Cov_Feat]
	rows[iid_hit, -1] = InitData[iid_hit, 2 * Cov_Feat]
	rows[paired_hit, -1] = InitData[paired_hit, 2 * Cov_Feat + 1]
	return rows

def ONS(v, w, a):
	"""
	One update step for the betting fraction (presumably Online Newton Step, per the name).

	v: payoff observed in the previous round
	w: current betting fraction
	a: running sum of squared gradient terms
	Returns the updated (a, w); the new w is clipped to [0, 0.5].
	"""
	loss = -v
	grad = loss / (1 - loss * w)
	a_next = a + grad**2
	proposal = w - 2 * grad / ((2 - np.log(3)) * a_next)
	# Project the proposal back onto [0, 0.5].
	w_next = min(0.5, max(0, proposal))
	return a_next, w_next
	

# Query-by-committee active learning: returns the indices of the examples every committee member labels positive and of the examples on which the members disagree
def QuerybyCommittee(Cls, Feat):
	"""
	Split the rows of ``Feat`` by the vote of the committee ``Cls.estimators_``.

	Cls: fitted ensemble exposing ``estimators_``, each member having ``predict``
	Feat: feature rows to vote on
	Returns (PosInd, DisagreementInd): row indices whose mean prediction is
	exactly 1, and row indices whose mean prediction is strictly between 0 and 1.
	"""
	members = Cls.estimators_
	votes = np.zeros((len(members), len(Feat)))
	for row, member in enumerate(members):
		votes[row] = member.predict(Feat)
	mean_vote = np.mean(votes, axis=0)

	unanimous_positive = mean_vote == 1
	split_vote = (mean_vote > 0) & (mean_vote < 1)
	PosInd = np.flatnonzero(unanimous_positive).tolist()
	DisagreementInd = np.flatnonzero(split_vote).tolist()
	return PosInd, DisagreementInd

# Active betting test for the match pair data
def MatchPairSignTest(args, BettingCls, EnrichmentCls, InitDataforBetting, InitDataforEnrichment, CovResponse, Labels, ResonseRegionTestData, ScalingStrategy):
	"""
	Active betting sign test on match-pair data with a classifier-based enrichment model.

	Each round queries one unlabelled pair, bets on the sign predicted by
	``BettingCls``, accumulates the log-wealth ``LogT`` and refits both models.
	The null is rejected once ``LogT >= log(1 / args.Alpha)``.

	args: run configuration; reads SeqStartPoint, Budget, QS ('Active',
		'original_active' or 'Passive'), Alpha, EarlyStopping, Enrichmentcls, Thres
	BettingCls: fitted classifier with predict_proba/fit, used for the bets
	EnrichmentCls: fitted ensemble (needs ``estimators_`` and predict_proba)
	InitDataforBetting: initial betting training set, last column is the label
	InitDataforEnrichment: initial enrichment training set, last column is the label
	CovResponse: per pair [iid covariates, paired covariates, iid response, paired response]
	Labels: per pair [iid assignment, paired assignment, enrichment label]
	ResonseRegionTestData: evaluation data; first CovLen columns are covariates, last column the label
	ScalingStrategy: only 'ONS' is implemented

	Returns (Stats, QueryIndex, Enrichment_prob_All, Val_Selected_Points,
	EffectSize_along_labelspent, TPR_along_labelspent, Precision_along_labelspent,
	acc_along_labelspent) where Stats = [reject flag, stopping time, TPR, FPR, fscore].

	Raises ValueError for an unsupported ScalingStrategy or args.QS (both
	previously surfaced as an UnboundLocalError).
	"""
	if ScalingStrategy != 'ONS':
		raise ValueError("Unsupported ScalingStrategy %r; only 'ONS' is implemented" % (ScalingStrategy,))
	if args.QS not in ('Active', 'original_active', 'Passive'):
		raise ValueError("Unsupported query strategy args.QS=%r" % (args.QS,))

	# Initialize variables
	Stats = np.zeros(5)  # [reject flag, stopping time, TPR, FPR, fscore]

	CovLen = int((CovResponse.shape[1] - 2)/2)
	QueryIndex = np.arange(args.SeqStartPoint).tolist()  # indices of queried pairs
	TrXBetting = InitDataforBetting[:, :-1]; TrYBetting = InitDataforBetting[:, -1].reshape((-1,1))
	TrXEnrichment = InitDataforEnrichment[:, :-1]; TrYEnrichment = InitDataforEnrichment[:, -1].reshape((-1,1))

	LogT = 0
	UnslectedIndex = list(np.arange(args.SeqStartPoint, len(CovResponse)))
	Count = 0
	Reject = 0
	# Defined up front so the function still returns when no query round runs.
	StopTime = len(QueryIndex)
	Enrichment_prob_All = {}  # enrichment probabilities, stored every 50 rounds

	pool_size = len(UnslectedIndex)
	EffectSize_along_labelspent = np.zeros(pool_size)  # effect size per label spent
	TPR_along_labelspent = np.zeros(pool_size)  # true positive rate per label spent
	acc_along_labelspent = np.zeros(pool_size)  # accuracy per label spent
	Precision_along_labelspent = np.zeros(pool_size)  # precision per label spent

	# The evaluation split does not change across rounds.
	ResonseRegionTestData_X = ResonseRegionTestData[:, :CovLen]
	ResonseRegionTestData_Y = ResonseRegionTestData[:, -1]

	def _evaluate(classifier):
		# Enrollment = unanimous positives + disagreement region of the committee.
		pos_ind, disagree_ind = QuerybyCommittee(classifier, ResonseRegionTestData_X)
		enrollment_ind = pos_ind + disagree_ind
		pred_enrollment = np.zeros(len(ResonseRegionTestData_Y)); pred_enrollment[enrollment_ind] = 1
		pred_postive = np.zeros(len(ResonseRegionTestData_Y)); pred_postive[pos_ind] = 1
		return GetEvalofCls(pred_enrollment, pred_postive, ResonseRegionTestData_Y)

	# Active betting
	while len(UnslectedIndex) > 0 and len(QueryIndex) < args.Budget:
		# Covariates of the pairs that have not been queried yet
		IIDCovariates = CovResponse[UnslectedIndex, :CovLen]
		Enrichment_prob = EnrichmentCls.predict_proba(IIDCovariates.reshape((-1, CovLen)))[:, -1]
		# Committee-voting based candidate selection
		pos_ind, disagree_ind = QuerybyCommittee(EnrichmentCls, IIDCovariates)
		if args.QS == 'Active':
			ResponsiveInd = pos_ind + disagree_ind
		elif args.QS == 'original_active':
			ResponsiveInd = pos_ind
		else:  # 'Passive'
			ResponsiveInd = []

		# debug
		if Count % 50 == 0:
			Enrichment_prob_All[Count // 50] = Enrichment_prob

		if len(ResponsiveInd) == 0 or args.QS == 'Passive':
			# No qualified pair, or passive querying: sample uniformly at random.
			Ind = random.randint(0, len(UnslectedIndex) - 1)
			ActiveQuery_Indicator = 0
			# An active strategy that fell back to random sampling carries the
			# previous effect size forward.
			if args.QS != 'Passive' and Count != 0:
				EffectSize_along_labelspent[Count] = EffectSize_along_labelspent[Count-1]
		else:
			Ind = random.choice(ResponsiveInd)
			ActiveQuery_Indicator = 1

		QInd = UnslectedIndex.pop(Ind)
		QueryIndex.append(QInd)

		pair = CovResponse[QInd]
		iid_cov = pair[:CovLen]; paired_cov = pair[CovLen : CovLen * 2]
		iid_response = pair[2 * CovLen]; paired_response = pair[2 * CovLen + 1]

		# iid covariates together with their response, as one test row for the betting classifier
		TestIIDCovRes = np.hstack((iid_cov.reshape((1, CovLen)), iid_response.reshape((1, 1))))

		# Record the effect size of the queried pair: response of the member
		# with assignment 1 minus the response of the other member.
		if ActiveQuery_Indicator == 1 or args.QS == 'Passive':
			if Labels[QInd][1] == 1:
				EffectSize_along_labelspent[Count] = paired_response - iid_response
			else:
				EffectSize_along_labelspent[Count] = iid_response - paired_response

		# Predicted sign (made before the classifier sees this pair)
		QueryProb0 = BettingCls.predict_proba(TestIIDCovRes)[0][0]
		PZ = -1 if QueryProb0 > 0.5 else 1

		# Betting fraction: start at 0, then update from the previous payoff.
		if len(QueryIndex) == args.SeqStartPoint + 1:
			w = 0; a = 1
		else:
			a, w = ONS(v, w, a)
		# Keep |w| strictly below 1 so that 1 + w * v stays positive.
		w = max(min(0.99999, w), -0.99999)

		# True sign and payoff
		Z = -1 if Labels[QInd][0] == 0 else 1
		v = Z * PZ
		LogT += np.log(1 + w * v)

		# Extend the betting training set with both members of the pair
		NewTrXBetting = np.zeros((2, CovLen + 1))
		NewTrYBetting = np.zeros((2, 1))
		NewTrXBetting[0, :CovLen] = iid_cov
		NewTrXBetting[0, CovLen] = iid_response
		NewTrXBetting[1, :CovLen] = paired_cov
		NewTrXBetting[1, CovLen] = paired_response
		NewTrYBetting[0] = Labels[QInd][0]
		NewTrYBetting[1] = Labels[QInd][1]
		TrXBetting = np.vstack((TrXBetting, NewTrXBetting))
		TrYBetting = np.vstack((TrYBetting, NewTrYBetting))

		# Extend the enrichment training set with the iid covariates only
		NewTrXEnrichment = np.zeros((1, CovLen))
		NewTrYEnrichment = np.zeros((1, 1))
		NewTrXEnrichment[0, :CovLen] = iid_cov
		NewTrYEnrichment[0, 0] = Labels[QInd][-1]
		TrXEnrichment = np.vstack((TrXEnrichment, NewTrXEnrichment))
		TrYEnrichment = np.vstack((TrYEnrichment, NewTrYEnrichment))

		# Refit the betting classifier and rebuild the enrichment classifier
		BettingCls.fit(TrXBetting, TrYBetting.reshape(-1))
		Trdata = np.hstack((TrXEnrichment, TrYEnrichment))
		EnrichmentCls = GetClassifier(args, args.Enrichmentcls, Trdata, Scale=1, ImbalanceProcess = False)

		# The stopping time keeps advancing until the first rejection.
		if Reject == 0:
			StopTime = len(QueryIndex)
			if LogT >= np.log(1 / args.Alpha):
				Reject = 1
				if args.EarlyStopping == 1:
					break

		TPR, FPR, Precision, fscore, acc = _evaluate(EnrichmentCls)
		TPR_along_labelspent[Count] = TPR
		Precision_along_labelspent[Count] = Precision
		acc_along_labelspent[Count] = acc
		Count += 1

	# Final statistics
	TPR, FPR, Precision, fscore, acc = _evaluate(EnrichmentCls)
	Stats[0] = Reject; Stats[1] = StopTime
	Stats[2] = TPR; Stats[3] = FPR; Stats[4] = fscore

	# Points of the evaluation set selected by the final enrichment classifier
	Val_Enrichment_prob = EnrichmentCls.predict_proba(ResonseRegionTestData[:, : CovLen])[:, -1]

	print("Reject:%d, Stopping num: %d, TPR: %.5f, FPR: %.5f, acc:%.5f"%(Reject, StopTime, TPR, FPR, acc))
	Val_Selected_Points = ResonseRegionTestData[Val_Enrichment_prob > args.Thres]
	return Stats, QueryIndex, Enrichment_prob_All, Val_Selected_Points, EffectSize_along_labelspent, TPR_along_labelspent, Precision_along_labelspent, acc_along_labelspent

# Active betting test for the match pair data (regression variant: treatment/control regressors select the responsive region)
def RegressionType_MatchPair_SignTest(args, BettingCls, ControlRegressor, TreatmentRegressor, InitDataforBetting, InitDataforControlRegression, InitDataforTreatementRegression, CovResponse, Labels, ResonseRegionTestData, ScalingStrategy):
	"""
	Active betting sign test on match-pair data with a regression-based enrichment model.

	A pair is considered responsive when the predicted treatment response minus
	the predicted control response exceeds ``args.RespondingThres``. Each round
	queries one unlabelled pair, bets on the sign predicted by ``BettingCls``,
	accumulates the log-wealth ``LogT`` and refits all three models. The null
	is rejected once ``LogT >= log(1 / args.Alpha)``.

	args: run configuration; reads SeqStartPoint, Budget, QS, Alpha,
		EarlyStopping, RespondingThres
	BettingCls: fitted classifier with predict_proba/fit, used for the bets
	ControlRegressor, TreatmentRegressor: fitted regressors with predict/fit
	InitDataforBetting: initial betting training set, last column is the label
	InitDataforControlRegression, InitDataforTreatementRegression: initial
		regression training sets, last column is the response
	CovResponse: per pair [iid covariates, paired covariates, iid response, paired response]
	Labels: per pair [iid assignment, paired assignment, enrichment label]
	ResonseRegionTestData: evaluation data; first CovLen columns are covariates, last column the label
	ScalingStrategy: only 'ONS' is implemented

	Returns (Stats, QueryIndex, Val_Selected_Points, EffectSize_along_labelspent,
	TPR_along_labelspent, Precision_along_labelspent, acc_along_labelspent)
	where Stats = [reject flag, stopping time, TPR, FPR, fscore].

	Raises ValueError for an unsupported ScalingStrategy (previously surfaced
	as an UnboundLocalError).
	"""
	if ScalingStrategy != 'ONS':
		raise ValueError("Unsupported ScalingStrategy %r; only 'ONS' is implemented" % (ScalingStrategy,))

	# Initialize variables
	Stats = np.zeros(5)  # [reject flag, stopping time, TPR, FPR, fscore]

	CovLen = int((CovResponse.shape[1] - 2)/2)
	QueryIndex = np.arange(args.SeqStartPoint).tolist()  # indices of queried pairs
	TrXBetting = InitDataforBetting[:, :-1]; TrYBetting = InitDataforBetting[:, -1].reshape((-1,1))
	TrXTreatment = InitDataforTreatementRegression[:, :-1]; TrYTreatment = InitDataforTreatementRegression[:, -1].reshape((-1,1))
	TrXControl = InitDataforControlRegression[:, :-1]; TrYControl = InitDataforControlRegression[:, -1].reshape((-1,1))

	LogT = 0
	UnslectedIndex = list(np.arange(args.SeqStartPoint, len(CovResponse)))
	Reject = 0
	# Defined up front so the function still returns when no query round runs.
	StopTime = len(QueryIndex)

	pool_size = len(UnslectedIndex)
	EffectSize_along_labelspent = np.zeros(pool_size)  # effect size per label spent
	TPR_along_labelspent = np.zeros(pool_size)  # true positive rate per label spent
	Precision_along_labelspent = np.zeros(pool_size)  # precision per label spent
	acc_along_labelspent = np.zeros(pool_size)  # accuracy per label spent
	Count = 0

	# The evaluation split does not change across rounds.
	ResonseRegionTestData_X = ResonseRegionTestData[:, :CovLen]
	ResonseRegionTestData_Y = ResonseRegionTestData[:, -1]

	def _evaluate_on_test_region():
		# Predicted effect on the evaluation set and the resulting metrics.
		effect = TreatmentRegressor.predict(ResonseRegionTestData_X) - ControlRegressor.predict(ResonseRegionTestData_X)
		Pred_Label = effect > args.RespondingThres
		return effect, GetEvalofCls(Pred_Label, Pred_Label, ResonseRegionTestData_Y)

	# Active betting
	while len(UnslectedIndex) > 0 and len(QueryIndex) < args.Budget:
		# Predicted effect for the pairs that have not been queried yet
		IIDCovariates = CovResponse[UnslectedIndex, :CovLen].reshape((-1, CovLen))
		Treatment_Prediction = TreatmentRegressor.predict(IIDCovariates)
		Control_Prediction = ControlRegressor.predict(IIDCovariates)
		Predic_Effect = Treatment_Prediction - Control_Prediction
		ResponsiveInd = np.where(Predic_Effect > args.RespondingThres)[0]

		if len(ResponsiveInd) == 0 or args.QS == 'Passive':
			# No qualified pair, or passive querying: sample uniformly at random.
			Ind = random.randint(0, len(UnslectedIndex) - 1)
			ActiveQueryIndicator = 0
			# An active strategy that fell back to random sampling carries the
			# previous effect size forward.
			if args.QS != 'Passive' and Count != 0:
				EffectSize_along_labelspent[Count] = EffectSize_along_labelspent[Count - 1]
		else:
			Ind = random.choice(ResponsiveInd)
			ActiveQueryIndicator = 1
		QInd = UnslectedIndex.pop(Ind)
		QueryIndex.append(QInd)

		pair = CovResponse[QInd]
		iid_cov = pair[:CovLen]; paired_cov = pair[CovLen : 2 * CovLen]
		iid_response = pair[2 * CovLen]; paired_response = pair[2 * CovLen + 1]

		# Record the effect size of an actively queried pair: response of the
		# member with assignment 1 minus the response of the other member.
		# NOTE(review): unlike MatchPairSignTest, passive queries do not record
		# an effect size here - confirm whether that is intended.
		if ActiveQueryIndicator == 1:
			if Labels[QInd][1] == 1:
				EffectSize_along_labelspent[Count] = paired_response - iid_response
			else:
				EffectSize_along_labelspent[Count] = iid_response - paired_response

		# iid covariates together with their response, as one test row for the betting classifier
		TestIIDCovRes = np.hstack((iid_cov.reshape((1, CovLen)), iid_response.reshape((1, 1))))

		# Predicted sign (made before the classifier sees this pair)
		QueryProb0 = BettingCls.predict_proba(TestIIDCovRes)[0][0]
		PZ = -1 if QueryProb0 > 0.5 else 1

		# Betting fraction: start at 0, then update from the previous payoff.
		if len(QueryIndex) == args.SeqStartPoint + 1:
			w = 0; a = 1
		else:
			a, w = ONS(v, w, a)
		# Keep |w| strictly below 1 so that 1 + w * v stays positive.
		w = max(min(0.99999, w), -0.99999)

		# True sign and payoff
		Z = -1 if Labels[QInd][0] == 0 else 1
		v = Z * PZ
		LogT += np.log(1 + w * v)

		# Extend the betting training set with both members of the pair
		NewTrXBetting = np.zeros((2, CovLen + 1))
		NewTrYBetting = np.zeros((2, 1))
		NewTrXBetting[0, :CovLen] = iid_cov
		NewTrXBetting[0, CovLen] = iid_response
		NewTrXBetting[1, :CovLen] = paired_cov
		NewTrXBetting[1, CovLen] = paired_response
		NewTrYBetting[0] = Labels[QInd][0]
		NewTrYBetting[1] = Labels[QInd][1]
		TrXBetting = np.vstack((TrXBetting, NewTrXBetting))
		TrYBetting = np.vstack((TrYBetting, NewTrYBetting))

		# Route each member of the pair to the regression set of its arm
		# (iid assignment 0 -> iid member is the control).
		if Labels[QInd, 0] == 0:
			control_cov, control_response = iid_cov, iid_response
			treated_cov, treated_response = paired_cov, paired_response
		else:
			control_cov, control_response = paired_cov, paired_response
			treated_cov, treated_response = iid_cov, iid_response
		NewTrXTreatment = np.zeros((1, CovLen)); NewTrXControl = np.zeros((1, CovLen))
		NewTrYTreatment = np.zeros((1, 1)); NewTrYControl = np.zeros((1, 1))
		NewTrXControl[0, :CovLen] = control_cov
		NewTrYControl[0, -1] = control_response
		NewTrXTreatment[0, :CovLen] = treated_cov
		NewTrYTreatment[0, -1] = treated_response

		TrXTreatment = np.vstack((TrXTreatment, NewTrXTreatment))
		TrYTreatment = np.vstack((TrYTreatment, NewTrYTreatment))
		TrXControl = np.vstack((TrXControl, NewTrXControl))
		TrYControl = np.vstack((TrYControl, NewTrYControl))

		# Refit the betting classifier and both regressors
		BettingCls.fit(TrXBetting, TrYBetting.reshape(-1))
		ControlRegressor.fit(TrXControl, TrYControl.reshape(-1))
		TreatmentRegressor.fit(TrXTreatment, TrYTreatment.reshape(-1))

		# The stopping time keeps advancing until the first rejection.
		if Reject == 0:
			StopTime = len(QueryIndex)
			if LogT >= np.log(1 / args.Alpha):
				Reject = 1
				if args.EarlyStopping == 1:
					break

		# Metrics after this round
		_, (TPR, FPR, Precision, fscore, acc) = _evaluate_on_test_region()
		TPR_along_labelspent[Count] = TPR
		Precision_along_labelspent[Count] = Precision
		acc_along_labelspent[Count] = acc  # was a bare expression, so accuracy was never stored
		Count += 1

	# Final statistics
	Test_predict_Effectsize, (TPR, FPR, Precision, fscore, acc) = _evaluate_on_test_region()
	Stats[0] = Reject; Stats[1] = StopTime
	Stats[2] = TPR; Stats[3] = FPR; Stats[4] = fscore

	# debug; points of the evaluation set selected by the final regressors
	ValResponsiveInd = np.where(Test_predict_Effectsize > args.RespondingThres)[0]
	Val_Selected_Points = ResonseRegionTestData[ValResponsiveInd]
	print("Reject:%d, Stopping num: %d, TPR: %.5f, FPR: %.5f"%(Reject, StopTime, TPR, FPR))
	return Stats, QueryIndex, Val_Selected_Points, EffectSize_along_labelspent, TPR_along_labelspent, Precision_along_labelspent, acc_along_labelspent

def GetClassifier(args, classifier_name, Trdata, Scale=1, ImbalanceProcess = False):
	"""
	Build the classifier selected by ``classifier_name`` and fit it on ``Trdata``.

	args: run configuration; ``args.Trial`` is used as random_state and
		``args.classifier_num`` as the number of bagged estimators for the
		'Ensemble_*' options
	classifier_name: one of 'logistic', 'Ensemble_LogisticRegression',
		'Ensemble_knn', 'Ensemble_DecisionTree', 'Ensemble_SVC', 'SVC', 'knn',
		'NN', 'DecisionTree'
	Trdata: 2-D array; the last column is the label, the rest are features
	Scale: unused; kept so existing callers that pass it keep working
	ImbalanceProcess: when true, the rows labelled 0 and 1 are rebalanced by
		resampling the smaller class with replacement up to the size of the
		larger one (rows with any other label are dropped by this step)

	Returns the fitted classifier.
	Raises ValueError if ``classifier_name`` is not supported (previously this
	surfaced as an UnboundLocalError at fit time).
	"""
	X = Trdata[:, :-1]; Y = Trdata[:, -1].reshape(-1)

	if classifier_name == 'logistic':
		Cls = LogisticRegression(random_state=args.Trial, warm_start=True)
	elif classifier_name == 'Ensemble_LogisticRegression':
		BaseCls = LogisticRegression(random_state=args.Trial, warm_start=True)
		Cls = BaggingClassifier(estimator=BaseCls, random_state=args.Trial, n_estimators=args.classifier_num, warm_start=True)
	elif classifier_name == 'Ensemble_knn':
		BaseCls = KNeighborsClassifier(algorithm='auto', n_neighbors=5)
		Cls = BaggingClassifier(estimator=BaseCls, random_state=args.Trial, n_estimators=args.classifier_num, warm_start=True)
	elif classifier_name == 'Ensemble_DecisionTree':
		# No base estimator given: BaggingClassifier falls back to its default.
		Cls = BaggingClassifier(random_state=args.Trial, n_estimators=args.classifier_num, warm_start=True)
	elif classifier_name == 'Ensemble_SVC':
		BaseCls = SVC(gamma='auto', kernel='rbf', random_state=args.Trial, probability = True)
		Cls = BaggingClassifier(estimator=BaseCls, random_state=args.Trial, n_estimators=args.classifier_num, warm_start=True)
	elif classifier_name == 'SVC':
		BaseCls = SVC(gamma='auto', kernel='rbf', random_state=args.Trial, probability = True)
		Cls = CalibratedClassifierCV(BaseCls, cv=3)
	elif classifier_name == 'knn':
		Cls = KNeighborsClassifier(algorithm='auto', n_neighbors=5)
	elif classifier_name == 'NN':
		Cls = MLPClassifier(solver='adam', alpha=1e-5, hidden_layer_sizes=(20,), max_iter=500, random_state=args.Trial)
	elif classifier_name == 'DecisionTree':
		BaseCls = tree.DecisionTreeClassifier()
		Cls = CalibratedClassifierCV(BaseCls, cv=5)
	else:
		raise ValueError("Unknown classifier_name: %r" % (classifier_name,))

	if ImbalanceProcess:
		X0 = X[Y==0]; X1 = X[Y==1]
		Y0 = Y[Y==0]; Y1 = Y[Y==1]
		# On a tie, class 1 is treated as the minority.
		if len(X0) < len(X1):
			minority_X = X0; majority_X = X1
			minority_Y = Y0; majority_Y = Y1
		else:
			minority_X = X1; majority_X = X0
			minority_Y = Y1; majority_Y = Y0
		# Oversample the minority class up to the majority size.
		resample_id = resample(list(np.arange(len(minority_X))), replace=True, n_samples=len(majority_X), random_state=args.Trial)
		minority_X_resample = minority_X[resample_id]
		minority_Y_resample = minority_Y[resample_id]
		X = np.vstack((majority_X, minority_X_resample)); Y = np.vstack((majority_Y.reshape((-1, 1)), minority_Y_resample.reshape((-1, 1))))
	Cls.fit(X, Y.reshape(-1))

	return Cls

def GetRegressor(args, ControlData, TreatmentData):
	"""
	Build and fit the control and treatment response regressors.

	args: run configuration; ``args.regressor`` selects the model ('GP',
		'linear', 'NN', 'DecisionTree', 'SVC' or 'knn') and ``args.Trial`` is
		used as random_state by the 'NN' option
	ControlData, TreatmentData: 2-D arrays; the last column is the response,
		the rest are covariates

	Returns (ControlRegressor, TreatmentRegressor), both fitted.
	Raises ValueError if ``args.regressor`` is not supported (previously this
	surfaced as an UnboundLocalError at fit time).
	"""
	TreatmentX = TreatmentData[:, :-1]; TreatmentY = TreatmentData[:, -1]
	ControlX = ControlData[:, :-1]; ControlY = ControlData[:, -1]

	# Two independent, identically configured models: one per arm.
	TreatmentRegressor = _new_regressor(args)
	ControlRegressor = _new_regressor(args)

	TreatmentRegressor.fit(TreatmentX, TreatmentY)
	ControlRegressor.fit(ControlX, ControlY)

	return ControlRegressor, TreatmentRegressor


def _new_regressor(args):
	"""Return a fresh, unfitted regressor for ``args.regressor``."""
	name = args.regressor
	if name == 'GP':
		return GaussianProcessRegressor(kernel=DotProduct() + WhiteKernel(), alpha=0.001)
	if name == 'linear':
		return LinearRegression()
	if name == 'NN':
		return MLPRegressor(solver='adam', alpha=1e-5, hidden_layer_sizes=(20,), max_iter=500, random_state=args.Trial)
	if name == 'DecisionTree':
		return tree.DecisionTreeRegressor()
	if name == 'SVC':
		# The option is named 'SVC' but builds a support vector regressor.
		return SVR(gamma='auto', kernel='rbf')
	if name == 'knn':
		return KNeighborsRegressor(n_neighbors=5)
	raise ValueError("Unknown regressor: %r" % (name,))









